[Feature] Mopd (Multi-Teacher On-Policy distillation) supported #2051

Draft
leoyuppieqnew wants to merge 15 commits into THUDM:main from leoyuppieqnew:mopd_opensource

leoyuppieqnew commented Jun 11, 2026

feat: Multi-Teacher On-Policy Distillation (MOPD)

Summary

Add Multi-Teacher On-Policy Distillation (MOPD) support to slime, enabling a single student model to distill knowledge from multiple domain-specific teachers simultaneously with importance sampling (IS) correction for stable off-policy training.

34 files changed, +9040 / -50 lines

Motivation

Standard on-policy distillation (OPD) supports only a single teacher. In practice, different domains (math, code, reasoning) benefit from specialized teacher models. MOPD extends OPD to aggregate knowledge from multiple domain-specific teachers into a single student, using per-teacher reverse KL advantages averaged across domains with IS-weight correction for stable training when the student policy diverges from the sampling policy.

Algorithm

MOPD supports three distillation strategies, controlled by --mopd-distill-type:

Token-Level Mode (token_level, default)

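Uses the sampled token's log-prob difference as a single-sample estimate of the negative reverse KL divergence, which serves as the per-token advantage (larger when the teacher assigns the sampled token more probability than the student does):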

reverse_kl_d(y_t) = sg[log π_d(y_t) - log π_θ(y_t)]

Training loss:

w_t = sg[π_θ(y_t) / μ_θ(y_t)]  clipped to [ε_low, ε_high]   # IS weight
Â_MOPD,t = (1/D) Σ_d (reverse_kl_d + α · Â_ORM)              # avg across D teachers
L = -E[1/|y| Σ_t w_t · Â_MOPD,t · log π_θ(y_t)]              # proxy policy loss
  • Memory: Negligible (1 scalar per token per teacher)
  • Accuracy: Noisy single-sample estimate of the KL (the teacher is evaluated only at the sampled token, not over the vocabulary)
  • Teacher modes: SGLang and Megatron
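
The three token-level formulas above can be traced numerically. The following is a minimal sketch in plain Python, not the code in `loss.py`: the function name and argument layout are hypothetical, and because there is no autograd here the stop-gradient `sg[·]` is only noted in comments. It returns the value of the proxy loss for one response.

```python
import math


def mopd_token_level_loss(student_logp, sampling_logp, teacher_logps,
                          orm_adv=None, alpha=0.0, eps_low=0.2, eps_high=5.0):
    """Value of the token-level proxy loss for a single response.

    student_logp[t]      log pi_theta(y_t), current student
    sampling_logp[t]     log mu_theta(y_t), policy that generated the rollout
    teacher_logps[d][t]  log pi_d(y_t), teacher d
    orm_adv[t]           outcome-reward advantage (zeros when omitted)
    """
    num_tokens, num_teachers = len(student_logp), len(teacher_logps)
    orm_adv = orm_adv if orm_adv is not None else [0.0] * num_tokens
    total = 0.0
    for t in range(num_tokens):
        # Clipped IS weight; a real implementation detaches it (stop-gradient).
        w = math.exp(student_logp[t] - sampling_logp[t])
        w = min(max(w, eps_low), eps_high)
        # Per-teacher advantage averaged over teachers; also detached in training.
        adv = sum(teacher_logps[d][t] - student_logp[t] + alpha * orm_adv[t]
                  for d in range(num_teachers)) / num_teachers
        total += w * adv * student_logp[t]
    return -total / num_tokens


student = [-1.0, -2.0]
sampling = [-1.0, -2.0]                  # on-policy, so every IS weight is 1
teachers = [[-0.5, -2.0], [-1.5, -1.0]]  # two teachers, two tokens
loss = mopd_token_level_loss(student, sampling, teachers)
print(loss)  # 0.5
```

At the first token the two teachers disagree symmetrically, so the averaged advantage is zero and only the second token contributes.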

Top-K Mode (top_k, recommended)

A memory-efficient approximation of the full-vocab KL. Only stores the teacher's top-k logits and indices, with an analytical tail correction:

D_KL(π_θ ∥ π_d) ≈ KL_topk + KL_tail

Top-K part (exact over the teacher's top-k tokens):

KL_topk = Σ_{y ∈ top-k} π_θ(y) [log π_θ(y) - log π_d(y)]

Tail correction (approximates the remaining vocabulary):

KL_tail ≈ π_θ,tail · log(π_θ,tail / π_d,tail)

where:

  • π_θ,tail = 1 - Σ_{y ∈ top-k} π_θ(y): the student's exact tail mass (computed via all-reduce across TP ranks)
  • Teacher tail mass π_d,tail:
    • SGLang mode: π_d,tail = 1 - Σ_{y ∈ top-k} exp(log π_d(y)), which is exact because the log-probs SGLang returns are normalized over the full vocabulary
    • Megatron mode: π_d,tail ≈ (V - V_eff) / V, a uniform assumption over non-top-k tokens

Full loss:

w_t = sg[π_θ(y_t) / μ_θ(y_t)]  clipped to [ε_low, ε_high]    # IS weight
L_topk_kl = (1/D) Σ_d (1/|y| Σ_t w_t · [KL_topk + KL_tail](π_θ ∥ π_d))  # IS-corrected approx KL
L = L_topk_kl + α · L_pg                                         # combined with PG loss
  • Memory: B × R × k × 2 × 4 bytes / TP (about 98.7% smaller than full_vocab for k=1024, V=152K, since 2k/V ≈ 1.3%)
  • Accuracy: Very close to full_vocab (top-k tokens capture >99% probability mass)
  • Teacher modes: SGLang and Megatron
  • TP-aware: SGLang returns global token IDs → converted to per-TP-shard local indices with -inf padding for out-of-shard entries; Megatron provides local indices directly
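
To make the top-k-plus-tail approximation concrete, here is a minimal single-position sketch in plain Python. It ignores tensor parallelism and batching, and the function names are hypothetical rather than the signatures of `vocab_parallel_topk_reverse_kl`. It assumes the teacher log-probs are normalized over the full vocabulary (the SGLang case described above).

```python
import math


def log_softmax(logits):
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def topk_reverse_kl(student_logits, teacher_topk_logp, teacher_topk_ids, eps=1e-12):
    """KL_topk + KL_tail for one position, given the teacher's top-k entries."""
    s_logp = log_softmax(student_logits)
    kl_topk, s_mass, t_mass = 0.0, 0.0, 0.0
    for lp_t, y in zip(teacher_topk_logp, teacher_topk_ids):
        p_s = math.exp(s_logp[y])
        kl_topk += p_s * (s_logp[y] - lp_t)   # exact term for token y
        s_mass += p_s
        t_mass += math.exp(lp_t)
    # Lump everything outside the top-k into one "tail" bucket per model.
    s_tail = max(1.0 - s_mass, eps)
    t_tail = max(1.0 - t_mass, eps)
    return kl_topk + s_tail * math.log(s_tail / t_tail)


def exact_reverse_kl(student_logits, teacher_logits):
    s, q = log_softmax(student_logits), log_softmax(teacher_logits)
    return sum(math.exp(a) * (a - b) for a, b in zip(s, q))


student_logits = [2.0, 1.0, 0.5, -1.0, -2.0]
teacher_logits = [1.5, 1.5, 0.0, -0.5, -3.0]
t_logp = log_softmax(teacher_logits)
top_ids = sorted(range(len(t_logp)), key=lambda i: -t_logp[i])[:2]   # k = 2
approx = topk_reverse_kl(student_logits, [t_logp[i] for i in top_ids], top_ids)
exact = exact_reverse_kl(student_logits, teacher_logits)
print(approx, exact)
```

By the log-sum inequality, collapsing the tail into a single bucket can only lower the KL, so the approximation never exceeds the exact value and the gap shrinks as the top-k set covers more of the probability mass.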

Full-Vocabulary Mode (full_vocab)

Computes the exact reverse KL over the entire vocabulary:

D_KL(π_θ ∥ π_d) = Σ_y π_θ(y) [log π_θ(y) - log π_d(y)]

Training loss:

w_t = sg[π_θ(y_t) / μ_θ(y_t)]  clipped to [ε_low, ε_high]   # IS weight
L_fv_kl = (1/D) Σ_d (1/|y| Σ_t w_t · D_KL(π_θ ∥ π_d))        # IS-corrected KL loss
L = L_fv_kl + α · L_pg                                         # combined with PG loss
  • Memory: Very high, B × R × V × 4 bytes / TP (stores full teacher logits)
  • Accuracy: Exact KL (gold standard)
  • Teacher modes: Megatron only (SGLang cannot efficiently return full-vocab logits)
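
The full-vocabulary loss combines three averages: over the vocabulary (the KL itself), over tokens (with IS weights), and over teachers. A minimal sketch in plain Python follows; the function names and the nested-list layout are hypothetical, and the IS weights are passed in precomputed.

```python
import math


def log_softmax(logits):
    m = max(logits)
    lse = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - lse for x in logits]


def reverse_kl(student_logits, teacher_logits):
    """D_KL(pi_theta || pi_d) at one position, summed over the whole vocabulary."""
    s, q = log_softmax(student_logits), log_softmax(teacher_logits)
    return sum(math.exp(a) * (a - b) for a, b in zip(s, q))


def full_vocab_kl_loss(student_logits, teacher_logits_by_domain, is_weights):
    """L_fv_kl for one response.

    student_logits[t][v], teacher_logits_by_domain[d][t][v], is_weights[t]
    """
    num_tokens = len(student_logits)
    per_teacher = []
    for teacher_logits in teacher_logits_by_domain:
        weighted = sum(w * reverse_kl(s, q)
                       for w, s, q in zip(is_weights, student_logits, teacher_logits))
        per_teacher.append(weighted / num_tokens)
    return sum(per_teacher) / len(per_teacher)


student = [[1.0, 0.0, -1.0], [0.5, 0.5, 0.0]]
teacher_a = [[1.0, 0.0, -1.0], [0.5, 0.5, 0.0]]   # identical to the student
teacher_b = [[0.0, 1.0, -1.0], [0.0, 0.0, 1.0]]
weights = [1.0, 0.8]
loss_two = full_vocab_kl_loss(student, [teacher_a, teacher_b], weights)
print(loss_two)
```

Because `teacher_a` matches the student exactly, it contributes zero and the two-teacher loss is half of what `teacher_b` alone would give.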

Comparison

| | token_level | top_k | full_vocab |
|---|---|---|---|
| KL accuracy | Point estimate | Approximate (~99%+) | Exact |
| Teacher data per token | 1 scalar | k×2 values | V values |
| Memory overhead | ≈ 0 | O(k) | O(V) |
| Teacher mode | SGLang / Megatron | SGLang / Megatron | Megatron only |
| Gradient | Through policy loss only | Through full student softmax | Through full student softmax |
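
The memory rows can be checked with a few lines of arithmetic. This sketch applies the per-mode formulas quoted in the sections above, assuming 4-byte values and reading R as the response length in tokens (an assumption; the description does not define R). The batch size, response length, and TP degree are illustrative.

```python
def teacher_buffer_bytes(batch, resp_len, values_per_token, tp=1, bytes_per_value=4):
    """Teacher data held per TP rank: B x R x (values per token) x bytes / TP."""
    return batch * resp_len * values_per_token * bytes_per_value / tp


B, R, TP = 8, 4096, 8          # illustrative sizes, not taken from the PR
V, k = 152_000, 1024           # vocabulary size and top-k from the description

full_vocab = teacher_buffer_bytes(B, R, V, TP)
top_k = teacher_buffer_bytes(B, R, 2 * k, TP)   # k logits + k indices per token
reduction = 1 - top_k / full_vocab

print(f"full_vocab: {full_vocab / 2**20:.0f} MiB per rank")
print(f"top_k:      {top_k / 2**20:.0f} MiB per rank")
print(f"reduction:  {reduction:.1%}")  # 98.7%
```

The ratio depends only on 2k/V, so B, R, and TP cancel out of the reduction figure.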

Key Features

  • Multi-teacher distillation: Aggregate knowledge from multiple domain-specific teachers into a single student
  • Importance sampling (IS) correction: Clipped IS weights w_t = sg[π_θ/μ_θ] ensure stable training
  • ORM combination: Coefficient α blends reverse KL with standard ORM advantages
  • Two teacher modes:
    • sglang — teachers on external SGLang servers (can have different architectures)
    • megatron — teachers loaded into Megatron (same architecture required)
  • Per-sample domain routing: Route samples to specific teachers via mopd_domains metadata field
  • Non-colocate mode: Separate actor training GPUs from SGLang rollout GPUs
  • Qwen3.5 VL MoE support: Bridge-mode weight sync for vision encoders via HfWeightIteratorBridge
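
Per-sample domain routing amounts to grouping sample indices by teacher. The sketch below assumes each sample carries a `metadata` dict whose `mopd_domains` entry is a list of domain names; that layout, the plain-dict samples, and the function name are illustrative assumptions, not the structures used in `mopd.py`.

```python
from collections import defaultdict


def route_by_domain(samples, teacher_urls):
    """Map each domain to the indices of the samples its teacher should score."""
    routes = defaultdict(list)
    for idx, sample in enumerate(samples):
        for domain in sample.get("metadata", {}).get("mopd_domains", []):
            if domain not in teacher_urls:
                raise KeyError(f"sample {idx} requests unknown domain {domain!r}")
            routes[domain].append(idx)
    return dict(routes)


teacher_urls = {
    "math": "http://math-teacher:13141/generate",   # hypothetical hosts
    "code": "http://code-teacher:13141/generate",
}
samples = [
    {"metadata": {"mopd_domains": ["math"]}},
    {"metadata": {"mopd_domains": ["code", "math"]}},
    {"metadata": {}},                                # no domain: not routed
]
routes = route_by_domain(samples, teacher_urls)
print(routes)  # {'math': [0, 1], 'code': [1]}
```

A sample listed under several domains is scored by each of those teachers, which is what the 1/D average in the loss expects.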

Architecture

┌─────────────────────────────────────────────────────────────────┐
│                         Rollout Phase                            │
├─────────────────────────────────────────────────────────────────┤
│  mopd.py::reward_func()                                         │
│    → Query teacher SGLang servers (concurrent, per-domain)      │
│    → Request top_logprobs_num=k per position                    │
│    → Extract top-k logprobs + token indices                     │
│    → Store in sample.mopd_teacher_topk_{logits,indices}         │
│                                                                  │
│  rollout.py::_split_train_data_by_dp()                          │
│    → Partition mopd dict data by DP rank                        │
└─────────────────────────────────────────────────────────────────┘
                              ↓
┌─────────────────────────────────────────────────────────────────┐
│                        Training Phase                            │
├─────────────────────────────────────────────────────────────────┤
│  actor.py::_get_rollout_data()                                  │
│    → Pop mopd_teacher_topk_{logits,indices} dict                │
│    → Convert global token IDs → per-TP-shard local indices      │
│    → Pad out-of-shard entries with -inf logit / 0 index         │
│    → Store as rollout_data["mopd_teacher_{domain}_topk_*"]      │
│                                                                  │
│  model.py::forward_step()                                       │
│    → Dynamic batch_keys (base + per-domain topk keys)           │
│    → get_batch() → DataIterator                                 │
│                                                                  │
│  loss.py::policy_loss_function()                                │
│    → get_responses(apply_temperature=False) → student logits    │
│    → apply_mopd_topk_to_loss()                                  │
│      → vocab_parallel_topk_reverse_kl() [TP-aware]             │
│      → IS-weighted KL loss + optional PG loss                   │
└─────────────────────────────────────────────────────────────────┘
                              ↓
┌─────────────────────────────────────────────────────────────────┐
│                      Weight Sync Phase                           │
├─────────────────────────────────────────────────────────────────┤
│  update_weight_from_distributed.py                              │
│    → if megatron_to_hf_mode == "bridge":                        │
│        _update_weights_bridge() via HfWeightIteratorBridge      │
│    → else:                                                       │
│        _send_weights() via template method pattern              │
└─────────────────────────────────────────────────────────────────┘
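
The global-to-local index conversion in the training phase can be sketched as follows. This assumes the vocabulary is split into equal contiguous shards across TP ranks, so rank r owns IDs in [r·V/TP, (r+1)·V/TP); the function name is hypothetical and the real logic in `actor.py` operates on tensors rather than lists.

```python
NEG_INF = float("-inf")


def to_local_topk(global_ids, logprobs, tp_rank, tp_size, padded_vocab_size):
    """Convert one position's top-k entries to the view of a single TP shard.

    Entries owned by another shard are padded with index 0 and a -inf logit,
    so they contribute nothing after exponentiation.
    """
    shard = padded_vocab_size // tp_size
    start, end = tp_rank * shard, (tp_rank + 1) * shard
    local_ids, local_logprobs = [], []
    for gid, lp in zip(global_ids, logprobs):
        if start <= gid < end:
            local_ids.append(gid - start)
            local_logprobs.append(lp)
        else:
            local_ids.append(0)
            local_logprobs.append(NEG_INF)
    return local_ids, local_logprobs


ids, lps = [1, 5, 7], [-0.1, -2.0, -3.0]
rank0 = to_local_topk(ids, lps, tp_rank=0, tp_size=2, padded_vocab_size=8)
rank1 = to_local_topk(ids, lps, tp_rank=1, tp_size=2, padded_vocab_size=8)
print(rank0)  # ([1, 0, 0], [-0.1, -inf, -inf])
print(rank1)  # ([0, 1, 3], [-inf, -2.0, -3.0])
```

Every global ID lands on exactly one shard, so summing the per-shard partial results (the all-reduce mentioned in the Top-K section) recovers the full top-k contribution.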

New Files

| File | Description |
|---|---|
| slime/rollout/mopd.py | SGLang-mode MOPD reward function & post-processing (780 lines) |
| slime/utils/ppo_utils.py (+369) | vocab_parallel_topk_reverse_kl, vocab_parallel_reverse_kl |
| slime_plugins/megatron_bridge/qwen35_vl_moe.py | Qwen3.5 VL MoE bridge plugin (1263 lines) |
| scripts/models/qwen3.5-397B-A17B.sh | Model config for 397B-A17B MoE |
| tools/convert_torch_dist_to_hf_parallel.py | Parallel distributed checkpoint conversion |
| tools/merge_missing_keys.py | Merge missing VL encoder keys into checkpoint |
| tools/patch_attention_gate_on_cluster.py | Attention gate patching utility |
| tests/test_mopd.py | Unit tests for MOPD token-level pipeline (555 lines) |
| tests/test_mopd_full_vocab.py | Unit tests for full-vocab distillation (1041 lines) |
| tests/test_mopd_sglang_topk_pipeline.py | Unit tests for SGLang top-k pipeline (747 lines) |

Modified Files (key changes)

| File | Changes |
|---|---|
| slime/backends/megatron_utils/actor.py (+317) | SGLang top-k data preparation: global→local TP index conversion |
| slime/backends/megatron_utils/loss.py (+829) | TopK KL loss, full-vocab KL loss, IS-weighted loss integration |
| slime/backends/megatron_utils/model.py (+59) | Dynamic batch_keys for per-domain MOPD teacher data |
| slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py (+49) | Bridge mode via HfWeightIteratorBridge for VL MoE |
| slime/utils/arguments.py (+376) | MOPD argument definitions, validation, auto-configuration |
| slime/ray/rollout.py (+86) | MOPD dict-typed data partitioning across DP ranks |
| slime/utils/types.py (+15) | messages field on Sample dataclass |
| slime/utils/data.py (+10) | filter_long_prompt multimodal input handling |

Example Scripts

| Script | Setup |
|---|---|
| run-qwen35-397B-A17B-mopd-topk-sglang.sh | 288 GPUs (256 train + 32 rollout), non-colocate, SGLang top-k |
| run-qwen35-35B-A3B-mopd-topk-sglang.sh | 48 GPUs (32 train + 16 rollout), non-colocate, SGLang top-k |

Usage

# 1. Start teacher SGLang server
python3 -m sglang.launch_server \
    --model-path /path/to/teacher \
    --host 0.0.0.0 --port 13141 \
    --tp 8 --ep-size 16 \
    --mem-fraction-static 0.7

# 2. Convert student to Megatron format
source scripts/models/qwen3.5-397B-A17B.sh
PYTHONPATH=/root/Megatron-LM torchrun --nproc_per_node=8 \
    tools/convert_hf_to_torch_dist.py \
    "${MODEL_ARGS[@]}" \
    --hf-checkpoint /path/to/student \
    --save /path/to/student_torch_dist

# 3. Run MOPD training
export MOPD_TEACHER_URLS='{"math":"http://teacher-host:13141/generate"}'  # port must match the server started in step 1
export MOPD_TEACHERS_JSON='[{"name":"math_teacher","domain":"math"}]'

ray job submit --address="http://127.0.0.1:8265" \
    -- python3 train.py \
    --use-mopd \
    --mopd-teacher-mode sglang \
    --mopd-distill-type top_k \
    --mopd-topk-k 128 \
    --mopd-alpha 0.0 \
    --mopd-eps-low 0.2 \
    --mopd-eps-high 5.0 \
    "${MODEL_ARGS[@]}" \
    ...
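
Before submitting the job it can help to check that the two environment variables agree with each other. The following is a small standalone sketch, not part of slime: it only assumes the JSON shapes shown in the export lines above (a domain-to-URL map and a list of `{"name", "domain"}` objects), and the function name is hypothetical.

```python
import json
import os


def load_teacher_config(env=os.environ):
    """Return {teacher name: URL}, failing early if a domain has no URL."""
    urls = json.loads(env["MOPD_TEACHER_URLS"])
    teachers = json.loads(env["MOPD_TEACHERS_JSON"])
    missing = [t["domain"] for t in teachers if t["domain"] not in urls]
    if missing:
        raise ValueError(f"no teacher URL configured for domains: {missing}")
    return {t["name"]: urls[t["domain"]] for t in teachers}


example_env = {
    "MOPD_TEACHER_URLS": '{"math": "http://teacher-host:13141/generate"}',
    "MOPD_TEACHERS_JSON": '[{"name": "math_teacher", "domain": "math"}]',
}
config = load_teacher_config(example_env)
print(config)  # {'math_teacher': 'http://teacher-host:13141/generate'}
```

Called without arguments, the function reads the real process environment, so it can be run in the same shell as the `export` lines.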

Key Arguments

| Argument | Description | Default |
|---|---|---|
| --use-mopd | Enable MOPD (mutually exclusive with --use-opd) | |
| --mopd-teachers | JSON list of teacher configs (or env MOPD_TEACHERS_JSON) | required |
| --mopd-teacher-mode | sglang or megatron | auto-detected |
| --mopd-distill-type | token_level, top_k, or full_vocab | token_level |
| --mopd-topk-k | Number of top-k tokens (for top_k mode) | 1024 |
| --mopd-alpha | ORM combination coefficient (0 = pure distillation) | 0.0 |
| --mopd-eps-low | IS weight lower clipping bound | 0.2 |
| --mopd-eps-high | IS weight upper clipping bound | 5.0 |
| --mopd-teacher-loads | Megatron teacher checkpoint paths (megatron mode) | |
| --mopd-sampling-logprobs-key | Key for sampling log-probs in IS computation | rollout_log_probs |
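
The argument table can be mirrored with `argparse` to see how the defaults and the `--use-opd` exclusion interact. This is an illustrative reconstruction from the table only; it is not the code in `slime/utils/arguments.py`, and the argument types and the `parse_mopd_args` helper are assumptions.

```python
import argparse
import json
import os


def build_parser():
    p = argparse.ArgumentParser()
    mode = p.add_mutually_exclusive_group()       # --use-mopd vs --use-opd
    mode.add_argument("--use-opd", action="store_true")
    mode.add_argument("--use-mopd", action="store_true")
    p.add_argument("--mopd-teachers", type=json.loads, default=None)
    p.add_argument("--mopd-teacher-mode", choices=["sglang", "megatron"], default=None)
    p.add_argument("--mopd-distill-type", default="token_level",
                   choices=["token_level", "top_k", "full_vocab"])
    p.add_argument("--mopd-topk-k", type=int, default=1024)
    p.add_argument("--mopd-alpha", type=float, default=0.0)
    p.add_argument("--mopd-eps-low", type=float, default=0.2)
    p.add_argument("--mopd-eps-high", type=float, default=5.0)
    p.add_argument("--mopd-teacher-loads", nargs="*", default=None)
    p.add_argument("--mopd-sampling-logprobs-key", default="rollout_log_probs")
    return p


def parse_mopd_args(argv, env=os.environ):
    parser = build_parser()
    args = parser.parse_args(argv)
    if args.use_mopd and args.mopd_teachers is None:
        raw = env.get("MOPD_TEACHERS_JSON")       # documented fallback
        if raw is None:
            parser.error("--mopd-teachers or MOPD_TEACHERS_JSON is required")
        args.mopd_teachers = json.loads(raw)
    return args


args = parse_mopd_args(
    ["--use-mopd", "--mopd-distill-type", "top_k", "--mopd-topk-k", "128"],
    env={"MOPD_TEACHERS_JSON": '[{"name": "math_teacher", "domain": "math"}]'},
)
print(args.mopd_distill_type, args.mopd_topk_k, args.mopd_eps_low, args.mopd_eps_high)
```

Passing both `--use-mopd` and `--use-opd` makes `argparse` exit with a usage error, which matches the "mutually exclusive" note in the table.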

Training Results

The following figure shows key training metrics from an MOPD top-k distillation run, confirming stable and effective convergence:

[Figure: training_metrics]

Key observations:

  • mopd_topk_kl / mopd_topk_kl/origin / mopd_topk_kl/enhanced: KL divergence between student and teacher decreases steadily, indicating the student successfully absorbs teacher knowledge across domains.
  • mopd_is_weight_mean: Importance sampling weights remain tightly centered around 1.0, indicating that the student stays close to the sampling policy; IS clipping (eps_low=0.2, eps_high=5.0) bounds any remaining outliers and so prevents variance explosion.
  • train_rollout_logprob_abs_diff: The absolute difference between training and rollout log-probs stays small and stable, confirming on-policy consistency.
  • grad_norm / entropy_loss / loss: Gradient norm stays bounded, entropy decreases gradually, and the overall loss converges — all signs of healthy training dynamics.
  • ppo_kl: Remains well-controlled, indicating the student policy is not drifting excessively from the reference policy.
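
Two of these metrics are simple functions of the training-time and rollout-time log-probs of the sampled tokens. The sketch below shows one plausible way to compute them; the exact reductions used in `loss.py` are not stated in this description, so the averaging and the extra `clip_fraction` field are assumptions.

```python
import math


def is_weight_stats(train_logp, rollout_logp, eps_low=0.2, eps_high=5.0):
    """Summarize IS weights for one batch of sampled tokens."""
    n = len(train_logp)
    raw = [math.exp(a - b) for a, b in zip(train_logp, rollout_logp)]
    clipped = [min(max(w, eps_low), eps_high) for w in raw]
    return {
        "mopd_is_weight_mean": sum(clipped) / n,
        "train_rollout_logprob_abs_diff":
            sum(abs(a - b) for a, b in zip(train_logp, rollout_logp)) / n,
        # Hypothetical extra metric: share of tokens whose weight hit a bound.
        "clip_fraction": sum(1 for w, c in zip(raw, clipped) if w != c) / n,
    }


on_policy = is_weight_stats([-1.0, -2.0, -0.5], [-1.0, -2.0, -0.5])
drifted = is_weight_stats([-1.0, -2.0, -0.5], [-1.0, -5.0, -0.5])
print(on_policy)
print(drifted)
```

In the on-policy case every weight is exactly 1 and nothing is clipped; in the drifted case the second token's raw weight is exp(3) ≈ 20, which the upper bound caps at 5.0.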

Testing

python -m pytest tests/test_mopd.py tests/test_mopd_full_vocab.py tests/test_mopd_sglang_topk_pipeline.py -v

Compatibility

  • Requires SGLang with top_logprobs_num support (recent versions)
  • Teacher vocabulary size must match student's padded_vocab_size (for SGLang top-k mode)
  • Bridge mode (--megatron-to-hf-mode bridge) required for Qwen3.5 VL MoE weight sync
  • Non-colocate mode requires separate GPU allocation for actor and rollout
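
For the first requirement, a teacher query could look roughly like the sketch below. The request fields (`return_logprob`, `logprob_start_len`, `top_logprobs_num`) and the `meta_info["input_top_logprobs"]` response layout are assumptions based on SGLang's native `/generate` API and should be checked against the installed version; the helper names and the mocked response are hypothetical, and no network call is made.

```python
def build_teacher_request(input_ids, k):
    """JSON body asking the teacher to score existing tokens, not generate."""
    return {
        "input_ids": input_ids,
        "sampling_params": {"temperature": 0, "max_new_tokens": 0},
        "return_logprob": True,
        "logprob_start_len": 0,     # score from the first token onward
        "top_logprobs_num": k,      # top-k alternatives per position
    }


def extract_topk(meta_info, k):
    """Return [(logprobs, token_ids), ...], one pair per scored position.

    Assumes each position is a list of (logprob, token_id, text) entries and
    that the first position may be None (no conditional log-prob exists).
    """
    rows = []
    for entry in meta_info["input_top_logprobs"]:
        if entry is None:
            continue
        top = entry[:k]
        rows.append(([item[0] for item in top], [item[1] for item in top]))
    return rows


payload = build_teacher_request([101, 11, 7], k=2)

# Mocked response fragment in the assumed shape, for illustration only.
mock_meta_info = {
    "input_top_logprobs": [
        None,
        [[-0.1, 11, None], [-2.5, 42, None]],
        [[-0.3, 7, None], [-1.4, 11, None]],
    ]
}
rows = extract_topk(mock_meta_info, k=2)
print(rows)  # [([-0.1, -2.5], [11, 42]), ([-0.3, -1.4], [7, 11])]
```

The token IDs returned this way are global vocabulary IDs, which is why the teacher's vocabulary must line up with the student's padded vocabulary before the per-shard conversion.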

燕一 added 11 commits June 10, 2026 15:16
- Add MOPD rollout module (slime/rollout/mopd.py) for multi-teacher distillation
- Add MOPD loss computation in megatron backend (KL divergence based)
- Add MOPD-related arguments (teacher config, distillation params)
- Add ray rollout integration for MOPD pipeline
- Add example scripts for Qwen3.5-35B-A3B MOPD training
- Add README documentation for MOPD feature
- Add unit tests for MOPD functionality
- Extend MOPD loss computation to support full vocabulary KL divergence
- Add parameterized distillation mode selection (token-level vs full-vocab)
- Add ppo_utils helpers for full vocabulary logits processing
- Modify model.py to support output_all_logits mode
- Add example script for full-vocab megatron training
- Add comprehensive unit tests for full vocabulary distillation
- Update README with full vocabulary distillation documentation
- Implement TopK token selection for efficient distillation loss computation
- Add TopK-related arguments (topk_tokens, topk_temperature)
- Add 397B model startup script (scripts/models/qwen3.5-397B-A17B.sh)
- Extend ppo_utils with TopK logits extraction and processing
- Update full-vocab megatron script with TopK options
- Extend tests for TopK distillation mode
- Add SGLang-based teacher rollout pipeline (separate from Megatron in-process mode)
- Implement HTTP-based teacher logprobs collection during rollout
- Add MOPD teacher URL configuration via environment variables
- Fix logits calculation bug in TopK mode
- Fix bad teacher request handling with retry logic
- Improve MOPD rollout logging and monitoring
- Add 397B model example scripts (megatron and sglang modes)
- Add README_zh.md with Chinese documentation
- Add comprehensive SGLang TopK pipeline integration tests
- Add Qwen3.5 VL MoE megatron bridge plugin (qwen35_vl_moe.py)
- Add multimodal input handling in MOPD rollout pipeline
- Add visual input processing with image token support
- Fix fused experts computation for VL MoE architecture
- Fix VL MoE model conversion (HF <-> torch_dist)
- Add 35B-A3B multimodal TopK SGLang training example script
- Register VL MoE bridge in megatron_bridge plugin
- Add Qwen3.5 MoE bridge conversion support in mbridge plugin
- Add parallel distributed conversion tool (convert_torch_dist_to_hf_parallel.py)
- Add merge_missing_keys.py for handling partial checkpoint merges
- Fix megatron_to_hf conversion for Qwen3.5 architecture
- Fix convert_torch_dist_to_hf_bridge.py quantization support
- Fix loss becoming inf due to numerical instability in KL computation
- Fix padding vocab size handling in actor forward pass
- Add train-memory-margin-bytes argument for memory management
- Add attention gate patching tool for distributed checkpoints
- Add safety checks for logits with padding tokens
- Update 397B SGLang script with stability improvements
- Add non-colocate mode support in update_weight_from_distributed.py
  (separate actor training GPUs from SGLang rollout GPUs)
- Add HfWeightIteratorBridge support for Megatron-to-HF conversion
  in weight update pipeline (supports VL MoE models)
- Switch 397B model script to use bridge mode for megatron-to-hf
- Update 397B SGLang script for non-colocate deployment
- Update 35B script with optimized parallelism settings
- Add GUIDE_qwen35_moe_mopd.md with detailed usage documentation
- Cover MOPD workflow, distillation modes (TopK, full-vocab)
- Document SGLang teacher server setup and configuration
- Document multi-teacher domain routing and hyperparameters
- Include troubleshooting and FAQ sections
- Fix filter_long_prompt to correctly process multimodal inputs when
  apply_chat_template has already converted prompt to a string
- Add 'messages' field to Sample dataclass to preserve raw message list
  for multimodal processing after chat template application
- Ensures vision info (images) can be extracted from original messages
  even when prompt has been templated
Remove megatron-mode and full-vocab example scripts that have
known OOM problems. Keep only the validated SGLang TopK scripts:
- run-qwen35-397B-A17B-mopd-topk-sglang.sh
- run-qwen35-35B-A3B-mopd-topk-sglang.sh
@leoyuppieqnew leoyuppieqnew marked this pull request as draft June 11, 2026 03:02
@leoyuppieqnew leoyuppieqnew marked this pull request as ready for review June 11, 2026 03:13
@leoyuppieqnew leoyuppieqnew marked this pull request as draft June 11, 2026 03:13
燕一 added 4 commits June 11, 2026 11:29
# Conflicts:
#	slime/backends/megatron_utils/data.py
#	slime/backends/megatron_utils/loss.py
#	slime/backends/megatron_utils/model.py
#	slime/backends/megatron_utils/update_weight/update_weight_from_distributed.py
#	slime/ray/rollout.py
#	slime/utils/types.py
- Add strict=False to zip() calls (B905)
- Rename unused loop variable domain to _domain (B007)
- Add from err to raise in except clause (B904)
- Remove unused variables process_group, k, last_file_idx/name (F841)
- Add noqa: F841 for intentionally assigned test variables
- Replace assert False with raise AssertionError (B011)
- Rename ambiguous variable l to v (E741)
- Apply black formatting to all modified files
1. Fix RuntimeError in test_topk_kl_identical_distributions: kl.item()
   fails on multi-element tensor, changed to (kl >= -0.1).all()

2. Add missing vocab_size=20 to TestApplyMopdTopkToLoss._make_args():
   apply_mopd_topk_to_loss requires args.vocab_size which was not set
@leoyuppieqnew

Author

@zhuzilin Hello! Would you mind reviewing this when you're available? I'd really appreciate it. Thank you!
